# NTIRE 2021 Real-Time Image Super-Resolution Challenge
import Utils_MAI as u
import numpy as np
import tensorflow as tf
from SRModels_MAI import *
import keras.models as km
import keras.layers as kl
import keras.backend as kb
import keras.optimizers as ko
import keras

# --- Run configuration -------------------------------------------------------
# When True (and `train` is True), start from the checkpoint at `model_name`
# instead of a freshly built network.
load_model = False
# When False, skip training entirely: load the checkpoint at `model_name` and
# run u.dataset_prediction on the validation split (presumably reports PSNR).
train = True

bs = 16            # batch size handed to u.TrainingDataset
downscale = 3      # super-resolution scale factor
# Training crop size: 32 px scaled by `downscale`.
# NOTE(review): presumably the HR-side crop size; confirm in u.TrainingDataset.
image_w = image_h = 32 * downscale
num_epochs = 5000

# Checkpoint path loaded via km.load_model (presumably also written by
# u.ModelSaveOnEpochEnd during training).
model_name = "../Model/best_model.h5"

# High-resolution images and, per the folder name, their bicubic x3 counterparts.
dataset_folder_hr = "../Other/DIV2K_train_HR"
dataset_folder_lr = "../Other/DIV2K_train_LR_bicubic_X3"

# Fraction passed to the training dataset; the validation dataset receives the
# remaining `1 - percent` (presumably the train/val split is done in Utils_MAI).
percent = 0.99
if train:
    # NOTE(review): the trailing 100 is a Utils_MAI-specific argument whose
    # meaning is not visible here; confirm in u.TrainingDataset.
    training_images = u.TrainingDataset(
        dataset_folder_hr, dataset_folder_lr, bs, image_w, image_h, downscale, percent, 100
    )

validation_images = u.ValidationDataset(dataset_folder_hr, dataset_folder_lr, downscale, 1 - percent)


def _bicubic_baseline_mse(sample):
    """Return the MSE between the third model input and the target of one validation sample.

    `sample` is one item of `validation_images` as an (inputs, targets) pair.
    Judging by the message printed below, inputs[2] is presumably the
    bicubic-upscaled image; the trailing [0] presumably selects the single item
    of a batch of size 1. Both images are passed through u.shave with
    `downscale` before comparison.
    """
    inputs, targets = sample[0], sample[1]
    error = u.shave(inputs[2][0], downscale) - u.shave(targets[0], downscale)
    return np.mean(np.power(error, 2))


# Index the dataset once per image so the input and the target come from the
# same load (and each image is read from disk only once).
mses = [_bicubic_baseline_mse(validation_images[idx]) for idx in range(len(validation_images))]

# PSNR with peak value 1 (i.e. assumes pixel values in [0, 1]), averaged over images.
mean_psnr = np.mean(10 * np.log10(1 / np.array(mses)))
print("Validation Set (Bicubic Interpolation) MEAN PSNR %f dB" % mean_psnr)


# Data-parallel training: MirroredStrategy replicates the model on each listed
# GPU and combines gradients with hierarchical-copy all-reduce.
cross_tower_ops = tf.distribute.HierarchicalCopyAllReduce()

# devicesK holds the first K GPUs; the strategy below uses all four, so this
# script expects at least four GPUs to be present.
devices4 = ["/gpu:0", "/gpu:1", "/gpu:2", "/gpu:3"]
devices1, devices2, devices3 = devices4[:1], devices4[:2], devices4[:3]

strategy = tf.distribute.MirroredStrategy(cross_device_ops=cross_tower_ops, devices=devices4)
print(f"Number of devices: {strategy.num_replicas_in_sync}")

# Objects needed to deserialize checkpoints written by this script: the custom
# loss/metric functions plus the `tf` / `kb` modules.
# NOTE(review): presumably `tf`/`kb` are referenced inside Lambda layers of the
# saved model; confirm against SRModels_MAI.
custom_objects = {'custom_mse': u.custom_mse, "custom_psnr": u.custom_psnr, "tf": tf, "kb": kb}

if train:
    with strategy.scope():
        if load_model:
            # Resume from the existing checkpoint (no need to build a fresh
            # network only to discard it).
            with keras.utils.custom_object_scope(custom_objects):
                model = km.load_model(model_name)
        else:
            model = Uint8ModelTestv34(downscale).network()

        model.summary()

        # Wrap the SR network in a model that also receives the HR ground truth
        # as an input, so the loss and metric can be attached with add_loss /
        # add_metric. SR_image is accepted as an input but is not referenced by
        # the loss or metric here.
        hr_image = kl.Input(shape=(None, None, 3), name="HR_image")
        sr_image = kl.Input(shape=(None, None, 3), name="SR_image")
        combined_model = km.Model([model.input, hr_image, sr_image], model.output, name="combined")

        hr_image_estimate = combined_model.output

        combined_model.add_loss(u.custom_charbon(hr_image, hr_image_estimate))
        # PSNR between [0, 1]-clipped images after u.keras_shave with `downscale`.
        combined_model.add_metric(u.custom_psnr(u.keras_shave(kb.clip(hr_image,0,1),downscale), u.keras_shave(kb.clip(hr_image_estimate,0,1),downscale)), name="custom_psnr")

        # `learning_rate` replaces the deprecated `lr` alias (warns in tf.keras
        # 2.11+, rejected by Keras 3).
        optim = ko.Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8)

        combined_model.compile(optimizer=optim)
        combined_model.summary()

    # NOTE(review): presumably a per-epoch warm-up/decay schedule that overrides
    # the optimizer's initial rate; argument meanings are defined in Utils_MAI.
    lr_warmer = u.WarmUpSchedulerPerEpoch(0.0025, 0.0001, 50, num_epochs)

    # Saves the inner SR `model` (not the combined wrapper) to `model_name`.
    model_saver = u.ModelSaveOnEpochEnd(model, model_name, False, validation_images, downscale)
    callback_list = [lr_warmer, model_saver]
    combined_model.fit(training_images, epochs=num_epochs, validation_data=validation_images, validation_batch_size=1,  callbacks=callback_list)

# Both modes finish the same way: reload the checkpoint stored at `model_name`
# (after training, whatever ModelSaveOnEpochEnd last wrote) and evaluate it on
# the validation split.
with keras.utils.custom_object_scope(custom_objects):
    model = km.load_model(model_name)

u.dataset_prediction(model, validation_images, downscale)
